import os
import pandas as pd
from torchvision.io import read_video
from tqdm import tqdm
def get_label(dataframe, wrong_paths):
    """Look up the label of every path in ``wrong_paths``.

    Args:
        dataframe: Annotation table read with ``header=None``; column ``0``
            holds the video path and column ``1`` its label.
        wrong_paths: Iterable of paths to look up (a list or a
            ``pandas.Series`` both work).

    Returns:
        A new DataFrame with columns ``['path', 'label']``, one row per entry
        of ``wrong_paths`` in the same order. When ``wrong_paths`` is empty the
        result is empty but still has those two columns.

    Raises:
        ValueError: if a path does not match exactly one row of ``dataframe``.
    """
    rows = []
    for path in wrong_paths:
        matches = dataframe[dataframe[0] == path]
        if len(matches) != 1:
            raise ValueError(
                f"expected exactly one row for path {path!r}, found {len(matches)}"
            )
        rows.append({'path': matches[0].item(), 'label': matches[1].item()})
    # Build the frame once instead of calling pd.concat per row, which copies
    # the accumulated frame on every iteration (quadratic in the input size).
    return pd.DataFrame(rows, columns=['path', 'label'])

def _broken_reason(video_path, min_bytes):
    """Return a short reason why ``video_path`` is unusable, or ``None`` if it is fine.

    Checks, in order: the file can be stat'ed, it is at least ``min_bytes``
    long, ``read_video`` can decode its 0-2 s window without raising, and that
    window contains at least one video frame.
    """
    try:
        size = os.path.getsize(video_path)
    except OSError as err:
        return f"cannot stat file ({err})"
    if size < min_bytes:
        # Too small to be a real clip; skip decoding it at all.
        return f"file too small ({size} bytes)"
    try:
        # Only the first 2 seconds are decoded to keep the scan cheap.
        frames = read_video(video_path, 0, 2, pts_unit="sec")[0]
    except Exception as err:  # decoder errors vary by backend; record, don't abort the scan
        return f"decode failed ({type(err).__name__}: {err})"
    if frames.size(0) == 0:
        return "no video frames decoded in the first 2 s"
    return None


def _find_broken_videos(paths, video_root, min_bytes):
    """Return the entries of ``paths`` whose file under ``video_root`` is unusable.

    Each entry is checked once and, if broken, appended once (with its reason
    printed), so the result length equals the number of broken rows.
    """
    broken = []
    for path in tqdm(paths):
        reason = _broken_reason(os.path.join(video_root, path), min_bytes)
        if reason is not None:
            tqdm.write(f"{path}: {reason}")
            broken.append(path)
    return broken


def main(
    annotation_dir='/data/hahmwj/Merge_adapter_v4/dataset/annotations/k400',
    video_root='/data/hahmwj/dataset/k400/test',
    min_bytes=1 * 1024,
):
    """Remove unusable test videos from the K400 test annotation file.

    Reads ``<annotation_dir>/test.csv`` (space-separated, no header; column 0
    is the video path relative to ``video_root``). It checks every listed video
    and rewrites ``test.csv`` **in place** without the broken entries.

    Args:
        annotation_dir: Directory containing ``test.csv``.
        video_root: Directory the paths in column 0 are relative to.
        min_bytes: Files smaller than this are treated as broken without
            attempting to decode them.
    """
    test_csv = os.path.join(annotation_dir, 'test.csv')
    test = pd.read_csv(test_csv, header=None, delimiter=' ')

    # Reset the index (optional; the index is not written out below).
    test = test.reset_index(drop=True)

    error = _find_broken_videos(test[0], video_root, min_bytes)

    print(f"rows before filtering: {len(test)}")
    print(f"broken videos: {len(error)}")
    test = test[~test[0].isin(error)]
    print(f"rows after filtering: {len(test)}")

    # NOTE: this overwrites the input annotation file; keep a copy if the
    # unfiltered list may be needed again.
    test.to_csv(test_csv, header=None, sep=' ', index=False)


if __name__ == '__main__':
    main()